Model interpretation for Visual Question Answering

In this notebook we demonstrate how to apply model interpretability algorithms from the Captum library to VQA models. More specifically, we explain model predictions by applying Integrated Gradients to a small sample of image-question pairs. More details about Integrated Gradients can be found in the original paper: https://arxiv.org/pdf/1703.01365.pdf

As a reference VQA model we use the following open source implementation: https://github.com/Cyanogenoid/pytorch-vqa

In [2]:
import os, sys
import numpy as np

# Clone PyTorch VQA project from: https://github.com/Cyanogenoid/pytorch-vqa and add to your filepath
# Replace <PROJECT-DIR> placeholder with your project directory path
sys.path.append(os.path.realpath('<PROJECT-DIR>/pytorch-vqa'))

# Clone PyTorch Resnet model from: https://github.com/Cyanogenoid/pytorch-resnet and add to your filepath
# We can also use standard resnet model from torchvision package, however the model from `pytorch-resnet` 
# is slightly different from the original resnet model and performs better on this particular VQA task
sys.path.append(os.path.realpath('<PROJECT-DIR>/pytorch-resnet'))
In [3]:
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

import resnet  # from pytorch-resnet

from PIL import Image

from model import Net, apply_attention, tile_2d_over_nd # from pytorch-vqa
from utils import get_transform # from pytorch-vqa

from captum.attr import IntegratedGradients
from captum.attr import InterpretableEmbeddingBase, TokenReferenceBase
from captum.attr import visualization, configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Loading VQA model

In [5]:
saved_state = torch.load('models/2017-08-04_00.55.19.pth', map_location=device)

# reading vocabulary from saved model
vocab = saved_state['vocab']

# reading word tokens from saved model
token_to_index = vocab['question']

# reading answers from saved model
answer_to_index = vocab['answer']

num_tokens = len(token_to_index) + 1

# reading answer classes from the vocabulary
answer_words = ['unk'] * len(answer_to_index)
for w, idx in answer_to_index.items():
    answer_words[idx] = w
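The loop above inverts the word-to-index answer vocabulary into an index-to-word list. A minimal sketch on a hypothetical toy vocabulary (the real answer_to_index is read from the checkpoint):

```python
# hypothetical toy answer vocabulary; the real one comes from the saved model
toy_answer_to_index = {'yes': 0, 'no': 1, 'gray': 2}

# invert word -> index into an index -> word list, with 'unk' as the default
toy_answer_words = ['unk'] * len(toy_answer_to_index)
for w, idx in toy_answer_to_index.items():
    toy_answer_words[idx] = w

print(toy_answer_words)  # ['yes', 'no', 'gray']
```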
In [6]:
# Load the predefined model
# `device_ids` contains a list of GPU ids which are used for parallelization supported by `DataParallel`
vqa_net = torch.nn.DataParallel(Net(num_tokens), device_ids=[0,1])
vqa_net.load_state_dict(saved_state['weights'])
vqa_net.to(device)
vqa_net.eval()
/data/users/vivekm/pytorch-vqa/model.py:90: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.
  init.xavier_uniform(w)
/data/users/vivekm/pytorch-vqa/model.py:86: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.
  init.xavier_uniform(self.embedding.weight)
/data/users/vivekm/pytorch-vqa/model.py:44: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.
  init.xavier_uniform(m.weight)
Out[6]:
DataParallel(
  (module): Net(
    (text): TextProcessor(
      (embedding): Embedding(15193, 300, padding_idx=0)
      (drop): Dropout(p=0.5, inplace=False)
      (tanh): Tanh()
      (lstm): LSTM(300, 1024)
    )
    (attention): Attention(
      (v_conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (q_lin): Linear(in_features=1024, out_features=512, bias=True)
      (x_conv): Conv2d(512, 2, kernel_size=(1, 1), stride=(1, 1))
      (drop): Dropout(p=0.5, inplace=False)
      (relu): ReLU(inplace=True)
    )
    (classifier): Classifier(
      (drop1): Dropout(p=0.5, inplace=False)
      (lin1): Linear(in_features=5120, out_features=1024, bias=True)
      (relu): ReLU()
      (drop2): Dropout(p=0.5, inplace=False)
      (lin2): Linear(in_features=1024, out_features=3000, bias=True)
    )
  )
)

Converting a string question into a tensor. The encode_question function below is similar to the original implementation of the encode_question method in the pytorch-vqa source code: https://github.com/Cyanogenoid/pytorch-vqa/blob/master/data.py#L110

In [7]:
def encode_question(question):
    """ Turn a question into a vector of indices and a question length """
    question_arr = question.lower().split()
    vec = torch.zeros(len(question_arr), device=device).long()
    for i, token in enumerate(question_arr):
        index = token_to_index.get(token, 0)
        vec[i] = index
    return vec, torch.tensor(len(question_arr), device=device)
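A quick sanity check of the encoding logic, using a hypothetical toy vocabulary in place of the checkpoint's token_to_index; out-of-vocabulary words map to index 0:

```python
# hypothetical toy vocabulary standing in for the checkpoint's token_to_index
toy_token_to_index = {'what': 1, 'color': 2, 'is': 3, 'the': 4, 'elephant': 5}

def toy_encode_question(question, token_to_index):
    """Lower-case, split on whitespace, and map each token to its index (0 = unknown)."""
    tokens = question.lower().split()
    return [token_to_index.get(t, 0) for t in tokens], len(tokens)

vec, length = toy_encode_question("What color is the elephant", toy_token_to_index)
print(vec, length)  # [1, 2, 3, 4, 5] 5
```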

Defining end-to-end VQA model

The original saved model does not have the image network's (ResNet's) layers attached to it. We attach them in the cell below using a forward hook. The rest of the model is identical to the original definition of the model: https://github.com/Cyanogenoid/pytorch-vqa/blob/master/model.py#L48

In [8]:
class ResNetLayer4(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.r_model = resnet.resnet152(pretrained=True)
        self.r_model.eval()
        self.r_model.to(device)
        self.buffer = None

        def save_output(module, input, output):
            self.buffer = output

        self.r_model.layer4.register_forward_hook(save_output)

    def forward(self, x):
        self.r_model(x)
        return self.buffer

class VQA_Resnet_Model(Net):
    def __init__(self, embedding_tokens):
        super().__init__(embedding_tokens)
        self.resnet_layer4 = ResNetLayer4()
    
    def forward(self, v, q, q_len):
        q = self.text(q, list(q_len.data))
        v = self.resnet_layer4(v)

        v = v / (v.norm(p=2, dim=1, keepdim=True).expand_as(v) + 1e-8)

        a = self.attention(v, q)
        v = apply_attention(v, a)

        combined = torch.cat([v, q], dim=1)
        answer = self.classifier(combined)
        return answer
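The forward-hook trick used in ResNetLayer4 can be demonstrated on a tiny stand-in module (a sketch with toy layer sizes, not the actual ResNet): register_forward_hook fires right after the chosen submodule's forward pass and lets us stash its intermediate activation.

```python
import torch

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(4, 3, bias=False)
        self.layer2 = torch.nn.Linear(3, 2, bias=False)

    def forward(self, x):
        return self.layer2(self.layer1(x))

model = TinyNet().eval()
buffer = {}

def save_output(module, input, output):
    # called by PyTorch right after layer1's forward
    buffer['layer1'] = output

handle = model.layer1.register_forward_hook(save_output)

x = torch.randn(1, 4)
_ = model(x)                       # hook captures layer1's activation as a side effect
print(buffer['layer1'].shape)      # torch.Size([1, 3])
handle.remove()                    # detach the hook when done
```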

Updating the weights from the saved model and removing the old model from memory.

In [9]:
vqa_resnet = VQA_Resnet_Model(vqa_net.module.text.embedding.num_embeddings)
# `device_ids` contains a list of GPU ids which are used for parallelization supported by `DataParallel`
vqa_resnet = torch.nn.DataParallel(vqa_resnet, device_ids=[0,1])

# saved vqa model's parameters
partial_dict = vqa_net.state_dict()

state = vqa_resnet.state_dict()
state.update(partial_dict)
vqa_resnet.load_state_dict(state)

vqa_resnet.to(device)
vqa_resnet.eval()

# This is the original VQA model without ResNet. Remove it, since we no longer need it
del vqa_net
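The partial-loading pattern above, copying the checkpoint's weights into the larger model's state dict while leaving the newly attached ResNet parameters untouched, can be sketched with two toy modules (hypothetical layer names and sizes):

```python
import torch

class SavedModel(torch.nn.Module):          # stands in for the checkpointed VQA net
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(2, 2)

class ExtendedModel(torch.nn.Module):       # stands in for VQA_Resnet_Model
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(2, 2)    # key shared with SavedModel
        self.extra = torch.nn.Linear(2, 2)  # e.g. the attached ResNet layers

saved, extended = SavedModel(), ExtendedModel()
extra_before = extended.extra.weight.clone()

state = extended.state_dict()
state.update(saved.state_dict())            # overwrite only the overlapping keys
extended.load_state_dict(state)

print(torch.equal(extended.lin.weight, saved.lin.weight))   # True: shared weights copied
print(torch.equal(extended.extra.weight, extra_before))     # True: extra layer untouched
```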
In [10]:
image_size = 448  # scale image to given size and center
central_fraction = 1.0

transform = get_transform(image_size, central_fraction=central_fraction)
    
def image_to_features(img):
    img_transformed = transform(img)
    img_batch = img_transformed.unsqueeze(0).to(device)
    return img_batch
/home/vivekm/local/anaconda3/envs/captum_new/lib/python3.7/site-packages/torchvision/transforms/transforms.py:210: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
  warnings.warn("The use of the transforms.Scale transform is deprecated, " +

In order to explain text features, we introduce an interpretable embedding layer, which allows us to access word embeddings and generate meaningful attributions for each embedding dimension.

The configure_interpretable_embedding_layer function separates the embedding layer from the model and precomputes word embeddings in advance. The embedding layer of our model is then replaced by an interpretable embedding layer which wraps the original embedding layer and takes word embedding vectors as inputs to the forward function. This allows us to generate baselines for word embeddings and compute attributions for each embedding dimension.

Note: after finishing interpretation it is important to call remove_interpretable_embedding_layer, which removes the interpretable embedding layer that we added for interpretation purposes and restores the original embedding layer in the model.
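The mechanism can be illustrated without Captum: think of the model as f(indices) = g(embed(indices)). Replacing the embedding layer effectively lets us call g on precomputed embedding vectors directly, which is what makes per-dimension baselines and attributions possible. A minimal torch sketch with toy sizes (not the Captum API itself):

```python
import torch

torch.manual_seed(0)
embedding = torch.nn.Embedding(10, 4)   # toy vocab of 10 tokens, embedding dim 4
head = torch.nn.Linear(4, 2)            # stands in for everything after the embedding

def model_from_indices(idx):
    return head(embedding(idx))

def model_from_embeddings(emb):
    # same model, but the forward starts from embedding vectors, so
    # baselines and attributions can be defined per embedding dimension
    return head(emb)

idx = torch.tensor([1, 2, 3])
precomputed = embedding(idx)            # analogous to indices_to_embeddings
out_a = model_from_indices(idx)
out_b = model_from_embeddings(precomputed)
print(torch.allclose(out_a, out_b))     # True: the two forwards agree
```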

In [11]:
interpretable_embedding = configure_interpretable_embedding_layer(vqa_resnet, 'module.text.embedding')
/data/users/vivekm/captum/captum/attr/_models/base.py:133: UserWarning: In order to make embedding layers more interpretable they will
        be replaced with an interpretable embedding layer which wraps the
        original embedding layer and takes word embedding vectors as inputs of
        the forward function. This allows to generate baselines for word
        embeddings and compute attributions for each embedding dimension.
        The original embedding layer must be set
        back by calling `remove_interpretable_embedding_layer` function
        after model interpretation is finished.
  after model interpretation is finished."""

Creating a reference (aka baseline / background) for the questions. This is necessary for baseline-based model interpretability algorithms, in this case Integrated Gradients. More details can be found in the original paper: https://arxiv.org/pdf/1703.01365.pdf
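To see the role the baseline plays, here is a small NumPy sketch of Integrated Gradients on a toy quadratic function (a Riemann-sum approximation of the path integral, not Captum's implementation): by the completeness axiom, the attributions sum to f(input) − f(baseline).

```python
import numpy as np

def f(x):
    return np.sum(x ** 2)              # toy model: f(x) = sum_i x_i^2

def grad_f(x):
    return 2.0 * x                     # its analytic gradient

def integrated_gradients(x, baseline, n_steps=50):
    # midpoint Riemann approximation of
    # IG_i = (x_i - x'_i) * integral_0^1 df/dx_i(x' + a * (x - x')) da
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    avg_grad = np.mean([grad_f(baseline + a * (x - baseline)) for a in alphas], axis=0)
    return (x - baseline) * avg_grad

x = np.array([1.0, -2.0, 3.0])
baseline = np.zeros_like(x)            # a zero reference, like the PAD-token baseline for text
attr = integrated_gradients(x, baseline)

# completeness axiom: attributions sum to f(x) - f(baseline)
print(attr, attr.sum(), f(x) - f(baseline))
```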

In [12]:
PAD_IND = token_to_index['pad']
token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)
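In effect, TokenReferenceBase builds a question-length sequence of the PAD token's index to serve as the text baseline; conceptually:

```python
# hypothetical pad index; the notebook reads the real one from token_to_index['pad']
PAD_IND = 0

def generate_reference(sequence_length, reference_token_idx=PAD_IND):
    # one reference (pad) token per word in the question
    return [reference_token_idx] * sequence_length

print(generate_reference(5))  # [0, 0, 0, 0, 0]
```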
In [13]:
# this is necessary for backpropagation through RNN models in eval mode
torch.backends.cudnn.enabled=False

Creating an instance of Integrated Gradients. It will be used to interpret the model's predictions.

In [14]:
ig = IntegratedGradients(vqa_resnet)

Defining a few test images for model interpretation purposes.

In [15]:
images = ['./img/vqa/siamese.jpg',
          './img/vqa/elephant.jpg',
          './img/vqa/zebra.jpg']
In [16]:
def vqa_resnet_interpret(image_filename, questions, targets):
    img = Image.open(image_filename).convert('RGB')
    original_image = transforms.Compose([transforms.Resize(int(image_size / central_fraction)),
                                   transforms.CenterCrop(image_size), transforms.ToTensor()])(img)
    
    image_features = image_to_features(img).requires_grad_().to(device)
    for question, target in zip(questions, targets):
        q, q_len = encode_question(question)

        q_input_embedding = interpretable_embedding.indices_to_embeddings(q).unsqueeze(0)

        # Making prediction. The output of prediction will be visualized later
        ans = vqa_resnet(image_features, q_input_embedding, q_len.unsqueeze(0))
        pred, answer_idx = F.softmax(ans, dim=1).data.cpu().max(dim=1)
        
        # generate reference for each sample
        q_reference_indices = token_reference.generate_reference(q_len.item(), 
                                                                 device=device).unsqueeze(0)
        q_reference = interpretable_embedding.indices_to_embeddings(q_reference_indices).to(device)
        attributions, delta = ig.attribute(inputs=(image_features, q_input_embedding),
                                            baselines=(image_features * 0.0, q_reference),
                                            target=answer_idx,
                                            additional_forward_args=q_len.unsqueeze(0),
                                            n_steps=30)
        # Visualize text attributions
        text_attributions_norm = attributions[1].sum(dim=2).squeeze(0).norm()
        vis_data_records = [visualization.VisualizationDataRecord(
                                attributions[1].sum(dim=2).squeeze(0) / text_attributions_norm,
                                pred[0].item(),
                                answer_words[ answer_idx ],
                                answer_words[ answer_idx ],
                                target,
                                attributions[1].sum(),       
                                question.split(),
                                0.0)]
        visualization.visualize_text(vis_data_records)

        # visualize image attributions
        original_im_mat = np.transpose(original_image.cpu().detach().numpy(), (1, 2, 0))
        attr = np.transpose(attributions[0].squeeze(0).cpu().detach().numpy(), (1, 2, 0))
        
        visualization.visualize_image_attr_multiple(attr, original_im_mat,
                                                    ["original_image", "heat_map"], ["all", "absolute_value"], 
                                                    titles=["Original Image", "Attribution Magnitude"], show_colorbar=True)
        print('Text Contributions: ', attributions[1].sum().item())
        print('Image Contributions: ', attributions[0].sum().item())
        print('Total Contribution: ', attributions[0].sum().item() + attributions[1].sum().item())
In [17]:
# the index of the image in the test set. Change it if you want to experiment with different test images/samples.
image_idx = 1 # elephant
vqa_resnet_interpret(images[image_idx], [
    "what is on the picture",
    "what color is the elephant",
    "where is the elephant"
], ['elephant', 'gray', 'zoo'])
Target Label | Predicted Label | Attribution Label | Attribution Score | Word Importance
elephant | elephant (0.55) | elephant | 8.21 | what is on the picture
Text Contributions:  8.207755088806152
Image Contributions:  18.176225662231445
Total Contribution:  26.383980751037598
Target Label | Predicted Label | Attribution Label | Attribution Score | Word Importance
gray | gray (0.78) | gray | 6.92 | what color is the elephant
Text Contributions:  6.91782808303833
Image Contributions:  10.076950073242188
Total Contribution:  16.994778156280518
Target Label | Predicted Label | Attribution Label | Attribution Score | Word Importance
zoo | zoo (0.18) | zoo | 18.23 | where is the elephant
Text Contributions:  18.232120513916016
Image Contributions:  24.851558685302734
Total Contribution:  43.08367919921875
In [18]:
import IPython
# Above cell generates an output similar to this:
IPython.display.Image(filename='img/vqa/elephant_attribution.jpg')
Out[18]:
In [19]:
image_idx = 0 # cat

vqa_resnet_interpret(images[image_idx], [
    "what is on the picture",
    "what color are the cat's eyes",
    "is the animal in the picture a cat or a fox",
    "what color is the cat",
    "how many ears does the cat have",
    "where is the cat"
], ['cat', 'blue', 'cat', 'white and brown', '2', 'at the wall'])
Target Label | Predicted Label | Attribution Label | Attribution Score | Word Importance
cat | cat (0.96) | cat | 6.75 | what is on the picture
Text Contributions:  6.752697944641113
Image Contributions:  4.206277847290039
Total Contribution:  10.958975791931152
Target Label | Predicted Label | Attribution Label | Attribution Score | Word Importance
blue | blue (0.52) | blue | 13.50 | what color are the cat's eyes
Text Contributions:  13.495126724243164
Image Contributions:  -5.307694435119629
Total Contribution:  8.187432289123535
Target Label | Predicted Label | Attribution Label | Attribution Score | Word Importance
cat | cat (0.58) | cat | 0.23 | is the animal in the picture a cat or a fox
Text Contributions:  0.22824923694133759
Image Contributions:  4.890588283538818
Total Contribution:  5.118837520480156
Target Label | Predicted Label | Attribution Label | Attribution Score | Word Importance
brown | brown (0.16) | white and brown | 11.51 | what color is the cat
Text Contributions:  11.512341499328613
Image Contributions:  -1.2306742668151855
Total Contribution:  10.281667232513428
Target Label | Predicted Label | Attribution Label | Attribution Score | Word Importance
2 | 2 (0.91) | 2 | 22.32 | how many ears does the cat have
Text Contributions:  22.316856384277344
Image Contributions:  -5.502903938293457
Total Contribution:  16.813952445983887
Target Label | Predicted Label | Attribution Label | Attribution Score | Word Importance
outside | outside (0.31) | at the wall | 16.41 | where is the cat
Text Contributions:  16.41370964050293
Image Contributions:  -4.363155841827393
Total Contribution:  12.050553798675537
In [20]:
# Above cell generates an output similar to this:
IPython.display.Image(filename='img/vqa/siamese_attribution.jpg')
Out[20]:
In [21]:
image_idx = 2 # zebra

vqa_resnet_interpret(images[image_idx], [
    "what is on the picture",
    "what color are the zebras",
    "how many zebras are on the picture",
    "where are the zebras"
], ['zebra', 'black and white', '2', 'zoo'])
Target Label | Predicted Label | Attribution Label | Attribution Score | Word Importance
zebra | zebra (0.60) | zebra | 7.54 | what is on the picture
Text Contributions:  7.54050350189209
Image Contributions:  11.20102310180664
Total Contribution:  18.74152660369873
Target Label | Predicted Label | Attribution Label | Attribution Score | Word Importance
black and white | black and white (0.79) | black and white | 11.07 | what color are the zebras
Text Contributions:  11.068336486816406
Image Contributions:  5.008208274841309
Total Contribution:  16.076544761657715
Target Label | Predicted Label | Attribution Label | Attribution Score | Word Importance
2 | 2 (0.72) | 2 | 23.89 | how many zebras are on the picture
Text Contributions:  23.89345359802246
Image Contributions:  1.4363466501235962
Total Contribution:  25.329800248146057
Target Label | Predicted Label | Attribution Label | Attribution Score | Word Importance
zoo | zoo (0.42) | zoo | 24.57 | where are the zebras
Text Contributions:  24.5710391998291
Image Contributions:  7.849122047424316
Total Contribution:  32.42016124725342
In [22]:
# Above cell generates an output similar to this:
IPython.display.Image(filename='img/vqa/zebra_attribution.jpg')
Out[22]:

As mentioned above, after we are done with interpretation we have to remove the interpretable embedding layer and set the original embedding layer back in the model.

In [23]:
remove_interpretable_embedding_layer(vqa_resnet, interpretable_embedding)